// Maths for Game Programming
// Quaternion class - to be completed in lesson
// Jason Colman Qantm London 2009

#include <assert.h>
#ifdef WIN32
#define _USE_MATH_DEFINES
#endif
#include <math.h>
#include "Quaternion.h"

// TODO Default ctor - set to identity
Quaternion::Quaternion()
{
	m_w = 1;
	m_x = 0;
	m_y = 0;
	m_z = 0;
}
// TODO Ctor - set from axis and angle
Quaternion::Quaternion(float radians, const Vec3f& axis)
{
	normalisedAxis = axis;
	normalisedAxis.Normalise();
	sinV = sin(radians * 0.5f);

	m_w = cos(radians * 0.5f);
	m_x = axis.x * sinV;
	m_y = axis.y * sinV;
	m_z = axis.z * sinV;
}

Quaternion::Quaternion(float radians, float x, float y, float z)
{
	normalisedAxis = Vec3f(x, y, z);
	normalisedAxis.Normalise();
	sinV = sin(radians * 0.5f);


	m_w = cos(radians * 0.5f);
	m_x = x * sinV;
	m_y = y * sinV;
	m_z = z * sinV;
}

// TODO Normalise()
void Quaternion::Normalise()
{
	float sq = (m_w * m_w) + (m_x * m_x) + (m_y * m_y) + (m_z * m_z);

	float oneOverSq = 1.0f / sqrt(sq);

	m_w *= oneOverSq;
	m_x *= oneOverSq;
	m_y *= oneOverSq;
	m_z *= oneOverSq;
}



void Quaternion::ToMatrix(float pMatrix[16]) const
{
	assert(pMatrix);

	pMatrix[ 0] = 1.0f - 2.0f * ( m_y * m_y + m_z * m_z ); 
	pMatrix[ 1] = 2.0f * (m_x * m_y + m_z * m_w);
	pMatrix[ 2] = 2.0f * (m_x * m_z - m_y * m_w);
	pMatrix[ 3] = 0.0f;  

	pMatrix[ 4] = 2.0f * ( m_x * m_y - m_z * m_w );  
	pMatrix[ 5] = 1.0f - 2.0f * ( m_x * m_x + m_z * m_z ); 
	pMatrix[ 6] = 2.0f * (m_z * m_y + m_x * m_w );  
	pMatrix[ 7] = 0.0f;  

	pMatrix[ 8] = 2.0f * ( m_x * m_z + m_y * m_w );
	pMatrix[ 9] = 2.0f * ( m_y * m_z - m_x * m_w );
	pMatrix[10] = 1.0f - 2.0f * ( m_x * m_x + m_y * m_y );  
	pMatrix[11] = 0.0f;  

	pMatrix[12] = 0;  
	pMatrix[13] = 0;  
	pMatrix[14] = 0;  
	pMatrix[15] = 1.0f;
}

Quaternion operator*(const Quaternion& p, const Quaternion& q)
{
	Quaternion r;

	r.m_w = p.m_w * q.m_w - p.m_x * q.m_x - p.m_y * q.m_y - p.m_z * q.m_z;
	r.m_x = p.m_w * q.m_x + p.m_x * q.m_w + p.m_y * q.m_z - p.m_z * q.m_y;
	r.m_y = p.m_w * q.m_y + p.m_y * q.m_w + p.m_z * q.m_x - p.m_x * q.m_z;
	r.m_z = p.m_w * q.m_z + p.m_z * q.m_w + p.m_x * q.m_y - p.m_y * q.m_x;

	//r.m_w += 0.1f;

	return r;
}

Quaternion Slerp(const Quaternion& from, const Quaternion& to, float t)
{
	float to1[4];
	double omega, cosom, sinom, scale0, scale1;

	// calc cosine
	cosom = from.m_x * to.m_x + from.m_y * to.m_y + from.m_z * to.m_z
		+ from.m_w * to.m_w;

	// adjust signs (if necessary)
	if (cosom < 0)
	{ 
		cosom = -cosom; 
		to1[0] = - to.m_x;
		to1[1] = - to.m_y;
		to1[2] = - to.m_z;
		to1[3] = - to.m_w;
	} 
	else  
	{
		to1[0] = to.m_x;
		to1[1] = to.m_y;
		to1[2] = to.m_z;
		to1[3] = to.m_w;
	}

	// calculate coefficients

	static const double DELTA = 0.001;
	if ( (1.0 - cosom) > DELTA ) 
	{
		// standard case (slerp)
		omega = acos(cosom);
		sinom = sin(omega);
		scale0 = sin((1.0 - t) * omega) / sinom;
		scale1 = sin(t * omega) / sinom;
	} 
	else 
	{        
		// "from" and "to" quaternions are very close 
		//  ... so we can do a linear interpolation
		scale0 = 1.0 - t;
		scale1 = t;
	}

	Quaternion res;

	// calculate final values
	res.m_x = (float)(scale0 * (double)from.m_x + scale1 * (double)to1[0]);
	res.m_y = (float)(scale0 * (double)from.m_y + scale1 * (double)to1[1]);
	res.m_z = (float)(scale0 * (double)from.m_z + scale1 * (double)to1[2]);
	res.m_w = (float)(scale0 * (double)from.m_w + scale1 * (double)to1[3]);

	return res;
}

Quaternion Quaternion::Conjugate()
{
	return Quaternion(m_w, -m_x, -m_y, -m_z);
}

Vec3f Quaternion::RotateOOBB(float angle, const Vec3f& v)
{
	//Quaternion vAfter = *this * Quaternion(angle, v.x, v.y, v.z) * Conjugate();
	//return Vec3f(vAfter.m_x, vAfter.m_y, vAfter.m_z);

	//assert(fabs(v.Length() - 1.0f) < 0.01f);

	//float halfAngle = angle * 0.5f;
	//float halfSin = sin(halfAngle);

	//m_w = cos(halfAngle);
	//m_x = v.x * halfSin;
	//m_y = v.y * halfSin;
	//m_z = v.z * halfSin;
	//return Vec3f(m_x, m_y, m_z);

	m_w = cos(angle * 0.5f);

	float sinAngleSq = 1.0f - m_w * m_w;
	if (sinAngleSq <= 0.0f)
	{
		return v;
	}

	float oneOverSinAngle = 1.0f / sqrt(sinAngleSq);
	return Vec3f(m_x * oneOverSinAngle, m_y * oneOverSinAngle, m_z * oneOverSinAngle);
}